from typing import Tuple

import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

from tqdm import tqdm

from config import Config


def grad(
    X: np.ndarray,
    y: np.ndarray,
    ws: np.ndarray,
    regularization: bool,
    lambda_: float
) -> np.ndarray:
    """
    Compute the per-client stochastic gradient of the loss at one sample.

    The data term is the gradient of ``0.5 * (<w, x> - y) ** 2``, i.e.
    ``(<w, x> - y) * x``. When ``regularization`` is set, the element-wise
    term ``lambda_ * w / (1 + w ** 2) ** 2`` is added, which is the gradient
    of ``(lambda_ / 2) * sum(w ** 2 / (1 + w ** 2))``. The regularization
    term is applied per coordinate (not summed over the feature axis), which
    is the same form ``evaluate`` uses for the full gradient.

    :param X: the data of shape (n_clients, data_dim), one sample per client
    :param y: the label of shape (n_clients, 1)
    :param ws: the parameters of shape (n_clients, data_dim); a single shared
        model of shape (data_dim,) is also accepted and is broadcast over
        the clients
    :param regularization: whether to add the regularization gradient
    :param lambda_: regularization coef (ignored when ``regularization`` is
        False)
    :return: the gradient of shape (n_clients, data_dim)
    """
    # Prediction error of each client on its own sample, shape (n_clients, 1).
    residual = np.sum(ws * X, axis=-1, keepdims=True) - y
    gradient = residual * X

    if regularization:
        # Element-wise: each coordinate only depends on its own weight.
        gradient = gradient + lambda_ * (ws / (1 + ws ** 2) ** 2)

    return gradient


def evaluate(
    test_data: np.ndarray,
    test_label: np.ndarray,
    w: np.ndarray,
    config: Config,
):
    """
    Score a global model on the test data of all clients.

    :param test_data: test data of shape (n_clients, n_test_samples, data_dim)
    :param test_label: test label of shape (n_clients, n_test_samples, 1)
    :param w: global model of shape (data_dim,)
    :param config: configuration object; its ``regularization`` flag and
        ``lambda_`` coefficient are read here
    :return: the test loss (MSE over all clients and samples) and the
        squared Euclidean norm of the full gradient at ``w``
    """
    # Unpacking also enforces that test_data is 3-dimensional.
    _, n_test_samples, _ = test_data.shape

    # Test loss: predictions have shape (n_clients, n_test_samples).
    predictions = np.dot(test_data, w)
    loss = mean_squared_error(predictions.ravel(), test_label.ravel())

    # Full gradient of the quadratic part is A @ w - b, where A is the
    # client-averaged X^T X / n_test_samples and b the averaged y * x.
    per_client_gram = 1/n_test_samples * np.matmul(
        np.transpose(test_data, [0, -1, 1]),
        test_data,
    )
    A = np.mean(per_client_gram, axis=0)
    b = np.mean(test_label * test_data, axis=(0, 1))

    full_gradient = np.dot(A, w) - b
    if config.regularization:
        # Element-wise gradient of the regularizer.
        full_gradient += config.lambda_ * (w/(1+w**2)**2)

    # Note: this is the *squared* norm.
    squared_grad_norm = np.linalg.norm(full_gradient) ** 2

    return loss, squared_grad_norm


def local_sgd(
    config: Config,
    train_data: np.ndarray,
    train_label: np.ndarray,
    test_data: np.ndarray,
    test_label: np.ndarray,
    w_0: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Run local SGD: every client takes ``config.local_steps`` SGD steps on its
    own copy of the model, then the copies are averaged into the new global
    model. Each training sample index is consumed exactly once, in order.

    :param config: configuration object (``local_steps``, ``local_lr``,
        ``regularization`` and ``lambda_`` are read)
    :param train_data: np array of shape (n_clients, n_train_samples, data_dim)
    :param train_label: np array of shape (n_clients, n_train_samples, 1)
    :param test_data: np array of shape (n_clients, n_test_samples, data_dim)
    :param test_label: np array of shape (n_clients, n_test_samples, 1)
    :param w_0: initial model of shape (data_dim,); it is copied, not modified
    :return:
        loss of shape (n_communication+1,)
        gradient norm of shape (n_communication+1,)
        w_hist of shape (n_communication+1, data_dim)
    """
    n_clients, n_train_samples, data_dim = train_data.shape

    # One communication round consumes `local_steps` samples per client.
    n_rounds = n_train_samples // config.local_steps

    w = w_0.copy()

    # Histories; slot 0 holds the metrics of the initial model.
    loss = np.zeros(n_rounds + 1)
    grad_norm = np.zeros(n_rounds + 1)
    w_hist = np.zeros((n_rounds + 1, data_dim))

    w_hist[0] = w
    loss[0], grad_norm[0] = evaluate(test_data, test_label, w, config)

    sample_idx = 0
    for round_idx in tqdm(range(n_rounds)):
        # Broadcast: every client starts the round from the global model.
        ws = np.array([w] * n_clients)

        for _ in range(config.local_steps):
            ws -= config.local_lr * grad(
                train_data[:, sample_idx, :],
                train_label[:, sample_idx, :],
                ws,
                config.regularization,
                config.lambda_,
            )
            # Move on to the next unused training sample.
            sample_idx += 1

        # Communicate: average the local models.
        w = np.mean(ws, axis=0)

        w_hist[round_idx + 1] = w
        loss[round_idx + 1], grad_norm[round_idx + 1] = evaluate(
            test_data, test_label, w, config
        )

    return loss, grad_norm, w_hist


def minibatch_sgd(
    config: Config,
    train_data: np.ndarray,
    train_label: np.ndarray,
    test_data: np.ndarray,
    test_label: np.ndarray,
    w_0: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Run minibatch SGD: in every round each client accumulates
    ``config.local_steps`` stochastic gradients, all evaluated at the current
    global model, and the global model takes one step along their average.

    :param config: configuration (``local_steps``, ``global_lr``,
        ``regularization`` and ``lambda_`` are read)
    :param train_data: training data of shape (n_clients, n_train_samples, data_dim)
    :param train_label: training label of shape (n_clients, n_train_samples, 1)
    :param test_data: test data of shape (n_clients, n_test_samples, data_dim)
    :param test_label: test label of shape (n_clients, n_test_samples, 1)
    :param w_0: initial model of shape (data_dim,); it is copied, not modified
    :return:
        loss of shape (n_communication+1,)
        gradient norm of shape (n_communication+1,)
        w_hist of shape (n_communication+1, data_dim)
    """
    n_clients, n_train_samples, data_dim = train_data.shape

    # One communication round consumes `local_steps` samples per client.
    n_rounds = n_train_samples // config.local_steps

    w = w_0.copy()

    # Histories; slot 0 holds the metrics of the initial model.
    loss = np.zeros(n_rounds + 1)
    grad_norm = np.zeros(n_rounds + 1)
    w_hist = np.zeros((n_rounds + 1, data_dim))

    w_hist[0] = w
    loss[0], grad_norm[0] = evaluate(test_data, test_label, w, config)

    sample_idx = 0
    for round_idx in tqdm(range(n_rounds)):
        # Per-client sum of the gradients seen in this round.
        grad_sum = np.zeros((n_clients, data_dim))

        for _ in range(config.local_steps):
            # The shared model `w` is broadcast over clients inside `grad`.
            grad_sum += grad(
                train_data[:, sample_idx, :],
                train_label[:, sample_idx, :],
                w,
                config.regularization,
                config.lambda_,
            )
            # Move on to the next unused training sample.
            sample_idx += 1

        # Average over clients, then over the steps of the round.
        step = config.global_lr * np.mean(grad_sum, axis=0) / config.local_steps
        w -= step

        w_hist[round_idx + 1] = w
        loss[round_idx + 1], grad_norm[round_idx + 1] = evaluate(
            test_data, test_label, w, config
        )

    return loss, grad_norm, w_hist


def local_sgd_momentum(
    config: Config,
    train_data: np.ndarray,
    train_label: np.ndarray,
    test_data: np.ndarray,
    test_label: np.ndarray,
    w_0: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Run local SGD with a global momentum term.

    In each round every client takes ``config.local_steps`` local steps along
    a mix of its stochastic gradient and the momentum of the previous round.
    The new momentum is the average of all those update directions (over
    clients and steps), and the global model takes one ``global_lr`` step
    along it. The local models are only used to evaluate the local gradients;
    they are not averaged into the global model.

    :param config: configuration (``local_steps``, ``local_lr``,
        ``global_lr``, ``momentum_coef``, ``regularization`` and ``lambda_``
        are read)
    :param train_data: training data of shape (n_clients, n_train_samples, data_dim)
    :param train_label: training label of shape (n_clients, n_train_samples, 1)
    :param test_data: test data of shape (n_clients, n_test_samples, data_dim)
    :param test_label: test label of shape (n_clients, n_test_samples, 1)
    :param w_0: initial model of shape (data_dim,); it is copied, not modified
    :return:
        loss of shape (n_communication+1,)
        gradient norm of shape (n_communication+1,)
        w_hist of shape (n_communication+1, data_dim)
    """
    n_clients, n_train_samples, data_dim = train_data.shape

    # One communication round consumes `local_steps` samples per client.
    n_rounds = n_train_samples // config.local_steps

    w = w_0.copy()

    # The momentum starts at zero and is refreshed after every round.
    momentum = np.zeros(data_dim)

    # Histories; slot 0 holds the metrics of the initial model.
    loss = np.zeros(n_rounds + 1)
    grad_norm = np.zeros(n_rounds + 1)
    w_hist = np.zeros((n_rounds + 1, data_dim))

    w_hist[0] = w
    loss[0], grad_norm[0] = evaluate(test_data, test_label, w, config)

    sample_idx = 0
    for round_idx in range(n_rounds):
        # Broadcast the current global model to every client.
        ws = np.array([w] * n_clients)

        # Per-client sum of the update directions used in this round.
        update_sum = np.zeros((n_clients, data_dim))

        for _ in range(config.local_steps):
            stochastic_grad = grad(
                train_data[:, sample_idx, :],
                train_label[:, sample_idx, :],
                ws,
                config.regularization,
                config.lambda_,
            )
            # Mix the fresh gradient with the momentum of the last round.
            client_update = (
                config.momentum_coef * stochastic_grad
                + (1 - config.momentum_coef) * momentum
            )

            ws -= config.local_lr * client_update
            update_sum += client_update

            # Move on to the next unused training sample.
            sample_idx += 1

        # Global step: the averaged update becomes the new momentum.
        momentum = np.mean(update_sum, axis=0) / config.local_steps
        w -= config.global_lr * momentum

        # Log the metrics of the updated global model.
        w_hist[round_idx + 1] = w
        loss[round_idx + 1], grad_norm[round_idx + 1] = evaluate(
            test_data, test_label, w, config
        )

    return loss, grad_norm, w_hist
